
from pyspark.sql import *
from pyspark.sql.types import *
import pyspark.sql.functions as f

if __name__ == '__main__':
    # spark=SparkSession.builder.appName("test16_udf_define").master("local[*]").getOrCreate()
    # sc=spark.sparkContext

    spark=SparkSession.builder.appName("test1_dataFrame_create")\
        .master("local[*]").getOrCreate()
    sc=spark.sparkContext
    rdd=sc.parallelize([1,2,3,4,5,6,7]).map(lambda t:[t])
    df=spark.createDataFrame(rdd,["num"])
    df.show()

    def num_ride_10(num:int) -> int:
        return num*10

    #     todo 注册方式1    DSL+SQL
    #     注册udf的名称，这个utf名称用于SQL风格
    #     udf的处理逻辑，是一个函数
    #     udf的返回值类型，一定要和函数的返回值相同
    #     返回值对象，这是一个udf对象，用于DSL语法
    # 这种自定义方式，可以通过参数1用于SQL风格，通过返回值对象用于DSL风格
    udf2=spark.udf.register("udf1",num_ride_10,IntegerType())
    # <class 'function'>
    print(type(udf2))

    # DSL风格使用
    # 只可以接受列名，返回值UDF对象，如果作为方法使用，传入的的参数，一定是Column对象
    df.select(udf2(df["num"])).show()

    # SQL 风格使用
    # 接受SQL expressions
    df.selectExpr("udf1(num)").show()

    #     todo 注册方式2    DSL
    udf3=f.udf(num_ride_10,IntegerType())
    print(type(udf3))

    df.select(udf3(df["num"])).show()
    # 报错 不能作为SQL风格使用
    # df.selectExpr("udf3(num)").show()







